import pickle

# Load the pickled narrative-model bundle. It is subscripted like a dict by
# the rest of this script.
# NOTE: pickle.load can execute arbitrary code embedded in the file — only
# load model files from a trusted source.
with open('../models/narrative_model_1000_clusters.pk', 'rb') as model_file:
    narrative_model = pickle.load(model_file)

from narrativeNLP.clustering import (
    get_vectors,
    SIF_keyed_vectors
)

# Unpack the two components of the bundle:
#   kmeans — the object stored at ['cluster_model'][0][0] (presumably the
#            fitted clustering model; not used further in this script)
#   sif    — the embeddings model, called below on lists of tokens
kmeans, sif = (
    narrative_model['cluster_model'][0][0],
    narrative_model['embeddings_model'],
)

import pandas as pd

df = pd.read_csv('../data/gpo_final_data/narratives_complete_with_metadata_manual_labels_rich_1000.csv')


def _role_counts(frame, column):
    """Count how often each distinct value occurs in ``frame[column]``.

    Returns a DataFrame with columns ``['index', column]`` (value, count),
    most frequent first. The names are pinned explicitly because
    ``Series.value_counts().reset_index()`` yields ``['index', column]`` on
    pandas < 2.0 but ``[column, 'count']`` on pandas >= 2.0, which would make
    the ``merge(on='index')`` below fail with a KeyError.
    """
    return frame[column].value_counts().rename_axis('index').reset_index(name=column)


# Per-role frequency tables: one row per distinct value with its count.
argO = _role_counts(df, 'ARGO')
arg1 = _role_counts(df, 'ARG1')
arg2 = _role_counts(df, 'ARG2')

# NOTE(review): inner joins keep only values that occur in ALL three role
# columns; a value frequent as ARGO but never seen as ARG2 is dropped before
# ranking. Kept as-is to reproduce the original selection — switch to
# how='outer' plus fillna(0) if a true total over all roles is intended.
arg = argO.merge(arg1, on='index', how='inner')
arg = arg.merge(arg2, on='index', how='inner')

# Rank values by their combined count across the three roles; keep the top 100.
arg['total_count'] = arg['ARGO'] + arg['ARG1'] + arg['ARG2']
top_entities = list(arg.sort_values(by='total_count', ascending=False)['index'])[0:100]

# Filter OUT named entities: keep only values NOT listed in the model's
# 'entities' collection.
top_entities = [entity for entity in top_entities if entity not in narrative_model['entities']]

# Take the five most frequent remaining values as plot examples.
examples = top_entities[0:5]

import plotly
import numpy as np
import plotly.graph_objs as go
from sklearn.decomposition import PCA

model = sif  # embeddings model; called below on a list of tokens per word
user_input = examples  # one trace (legend entry) per example entity
topn = 5  # maximum number of raw surface forms plotted per example
output_path = '../figures/Figure_A_1.html'

# For each example entity, collect up to `topn` of its most frequent raw
# surface forms (column 'ARGO-RAW') among rows whose 'ARGO' equals the entity.
# `group_sizes` records how many words each example actually contributed, so
# the traces below slice `words` correctly even when an entity has fewer than
# `topn` distinct surface forms (a fixed stride of `topn` would shift every
# later group and attach points to the wrong legend entry).
words = []
group_sizes = []
for example in examples:
    raw_counts = df.loc[df['ARGO'] == example, 'ARGO-RAW'].value_counts()
    # Read the values straight from the index: `.reset_index()['index']` only
    # works on pandas < 2.0, where the value column was named 'index'.
    variants = list(raw_counts.index[:topn])
    words.extend(variants)
    group_sizes.append(len(variants))

if not words:
    raise ValueError('no ARGO-RAW surface forms found for the selected examples')

# Embed each surface form from its whitespace-split tokens.
split_words = [w.split() for w in words]
word_vectors = np.array([model(w) for w in split_words])

# Project onto the first two principal components. Despite its name,
# `three_dim` holds only two columns; the plot is 2-D.
three_dim = PCA(random_state=0).fit_transform(word_vectors)[:, :2]

data = []
count = 0  # start offset of the current example's group within `words`
for name, size in zip(user_input, group_sizes):
    trace = go.Scatter(
        x=three_dim[count:count + size, 0],
        y=three_dim[count:count + size, 1],
        text=words[count:count + size],
        name=name,
        textposition="top center",
        mode='markers+text',
        marker={
            'size': 10,
            'opacity': 0.8,
            # NOTE(review): the same scalar colour value is given to every
            # trace; confirm whether distinct per-example colours were intended.
            'color': 2
        }
    )
    data.append(trace)
    count += size

# NOTE(review): every word belongs to one of the example groups above, so this
# slice is empty and the trace appears to contribute only an "input words"
# legend entry. Kept to preserve the original figure — confirm whether it can
# be removed.
trace_input = go.Scatter(
    x=three_dim[count:, 0],
    y=three_dim[count:, 1],
    text=words[count:],
    name='input words',
    textposition="top center",
    mode='markers+text',
    marker={
        'size': 10,
        'opacity': 1,
        'color': 'black'
    }
)

data.append(trace_input)

layout = go.Layout(
    margin={'l': 0, 'r': 0, 'b': 0, 't': 0},
    showlegend=True,
    plot_bgcolor='white',
    legend=dict(
        x=1,
        y=0.5,
        font=dict(
            family="Courier New",
            size=25,
            color="black"
        )),
    font=dict(
        # Same family as the legend; the stray spaces around the name were removed.
        family="Courier New",
        size=25),
    autosize=False,
    width=1800,
    height=900
)

plot_figure = go.Figure(data=data, layout=layout)
plot_figure.write_html(output_path)
